import matplotlib.pyplot as plt
import numpy as np

# --- Predictor variable names grouped by chart ---
# Each list selects the predictors shown on one chart and fixes their x-axis order.
treatment_vars = ["Human_Only", "Tactical_AI", "Strategic_AI"]
covariate_vars = ["Education", "Career_Length", "Cyber_Know_How", "FP_Know_How", "Risk_Aversion"]

# --- B coefficients and SEs for DV1 and DV2 ---
# One row per predictor so every estimate sits next to its standard error:
#   (predictor, B for DV #1, SE for DV #1, B for DV #2, SE for DV #2)
_ESTIMATES = [
    ("Human_Only", 0.858, 0.235, 0.630, 0.258),
    ("Tactical_AI", 1.434, 0.240, 1.000, 0.250),
    ("Strategic_AI", 0.396, 0.279, 0.472, 0.288),
    ("Education", -0.007, 0.049, -0.057, 0.047),
    ("Career_Length", -0.041, 0.017, -0.028, 0.016),
    ("Cyber_Know_How", -0.203, 0.100, -0.311, 0.097),
    ("FP_Know_How", 0.066, 0.176, 0.393, 0.164),
    ("Risk_Aversion", 0.555, 0.197, 1.740, 0.196),
]

# Per-column lookups keyed by predictor name, in the table's row order.
B_DV1 = {row[0]: row[1] for row in _ESTIMATES}
SE_DV1 = {row[0]: row[2] for row in _ESTIMATES}
B_DV2 = {row[0]: row[3] for row in _ESTIMATES}
SE_DV2 = {row[0]: row[4] for row in _ESTIMATES}


def plot_group(variables, title, filename, *, ylabel="Mean Treatment Effects", z=1.96):
    """Plot DV #1 and DV #2 coefficients with error bars for the given predictors.

    For each predictor name in *variables*, the B coefficients from ``B_DV1``
    (red) and ``B_DV2`` (blue) are drawn side by side as points. Each point has
    an error bar of half-width ``z * SE``, using ``SE_DV1`` / ``SE_DV2``. A
    dotted horizontal line marks zero. The figure is saved to *filename* at
    300 dpi and then displayed with ``plt.show()``. It is not closed here, so
    the caller should close it when done.

    Args:
        variables: Predictor names in x-axis order. Each must be a key of
            ``B_DV1``, ``SE_DV1``, ``B_DV2`` and ``SE_DV2`` (otherwise
            ``KeyError``).
        title: Axes title.
        filename: Output path passed to ``Figure.savefig``.
        ylabel: Y-axis label. The default keeps the original wording, which
            only fits treatment predictors.
        z: Multiplier turning each SE into an error-bar half-width. The
            default of 1.96 matches the original hard-coded value.

    Returns:
        The ``(fig, ax)`` pair so callers can close or further customise it.
    """
    x = np.arange(len(variables))
    width = 0.35  # horizontal offset between the DV #1 and DV #2 markers

    b1 = [B_DV1[v] for v in variables]
    b2 = [B_DV2[v] for v in variables]
    err1 = [z * SE_DV1[v] for v in variables]
    err2 = [z * SE_DV2[v] for v in variables]

    fig, ax = plt.subplots(figsize=(10, 5))

    ax.errorbar(x - width / 2, b1, yerr=err1,
                fmt='o', color='red', capsize=5, label='DV #1')
    ax.errorbar(x + width / 2, b2, yerr=err2,
                fmt='o', color='blue', capsize=5, label='DV #2')

    # Neutral colour so the zero reference is not mistaken for the red DV #1 series.
    ax.axhline(0, color='gray', linestyle='dotted')
    ax.set_xticks(x)
    ax.set_xticklabels(variables, fontsize=10)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Legend sits below the axes area (negative y in axes coordinates).
    ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.25), ncol=2, frameon=False)

    for spine in ['top', 'right', 'left']:
        ax.spines[spine].set_visible(False)

    ax.grid(axis='y', linestyle='--', alpha=0.5)
    fig.tight_layout()
    # Save this specific figure (not whatever is "current"); bbox_inches='tight'
    # sizes the output to include the legend drawn outside the axes.
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()
    return fig, ax


# --- Run the plots ---
plot_group(treatment_vars, "Treatment Effects on DVs", "treatment_effects.png")
plt.close()

# Covariates are not treatments, so they get their own y-axis label.
plot_group(covariate_vars, "Covariate Effects on DVs", "covariate_effects.png",
           ylabel="Covariate Coefficients (B)")
plt.close()
